#include "esmd_types.h"

#include "memory.h"
#include "cell.h"
#include "comm.h"
#include <math.h>
#include "nose_hoover.h"
/*
 * Initialise a Nose-Hoover chain thermostat from the run configuration.
 *
 * Copies the thermostat parameters from conf and derives the quantities the
 * integrator needs (tfreq, tdrag_factor, ncfac, tdof, tfactor). It then
 * resets the chain state: eta, eta_dot and eta_dotdot start at zero, and the
 * chain masses are sized for the starting temperature tstart.
 *
 * tstat - thermostat to fill. eta, eta_dot, eta_dotdot and eta_mass must
 *         already point at arrays of at least mtchain + 1 entries. Index
 *         mtchain is zeroed here; the integrator reads eta_dot[mtchain] as
 *         the end of the chain.
 * grid  - only grid->natoms is read, to count degrees of freedom.
 * conf  - source of the user-set thermostat parameters.
 */
void nh_init(nh_tstat_t *tstat, cellgrid_t *grid, esmd_config_t *conf){
  tstat->tstart = conf->tempstart;
  tstat->tstop = conf->tempstop;
  tstat->tdamp = conf->tempdamp;
  tstat->mtchain = conf->tempchain;
  tstat->ftm2v = conf->ftm2v;
  tstat->mvv2e = conf->mvv2e;
  tstat->boltz = conf->boltz;
  tstat->dt = conf->dt;
  tstat->nc_tchain = conf->tempnc;
  tstat->drag = conf->tempdrag;

  /* tfreq must be set before tdrag_factor, which is derived from it. */
  tstat->tfreq = 1.0 / tstat->tdamp;
  tstat->tdrag_factor = 1.0 - (tstat->dt * tstat->tfreq * tstat->drag / tstat->nc_tchain);
  tstat->ncfac = 1.0 / tstat->nc_tchain;
  /* 3N - 3 degrees of freedom; presumably the centre-of-mass motion is
     excluded -- TODO confirm. */
  tstat->tdof = (grid->natoms - 1) * 3;
  tstat->tfactor = tstat->mvv2e / (tstat->tdof * tstat->boltz);

  real *eta = tstat->eta;
  real *eta_mass = tstat->eta_mass;
  real *eta_dot = tstat->eta_dot;
  real *eta_dotdot = tstat->eta_dotdot;
  /* Zero the whole chain including the extra terminator slot at mtchain. */
  for (int ich = 0; ich <= tstat->mtchain; ich ++) {
    eta[ich] = 0;
    eta_dot[ich] = 0;
    eta_mass[ich] = 0;
    eta_dotdot[ich] = 0;
  }

  real t_target = tstat->tstart;

  /* The first link couples to all tdof degrees of freedom; later links each
     couple to the single link below them. */
  eta_mass[0] = tstat->tdof * tstat->boltz * t_target / (tstat->tfreq * tstat->tfreq);
  for (int ich = 1; ich < tstat->mtchain; ich ++) {
    eta_mass[ich] = tstat->boltz * t_target / (tstat->tfreq * tstat->tfreq);
  }

  /* Initial force on each later link. A non-positive mass (e.g. tstart == 0)
     would give 0/0 = NaN, so such links get zero force instead. */
  for (int ich = 1; ich < tstat->mtchain; ich ++) {
    if (eta_mass[ich] > 0) {
      eta_dotdot[ich] = (eta_mass[ich-1]*eta_dot[ich-1]*eta_dot[ich-1] - tstat->boltz * t_target) / eta_mass[ich];
    } else {
      eta_dotdot[ich] = 0;
    }
  }
}

/*
 * Advance the Nose-Hoover chain thermostat by half a timestep and return the
 * factor by which particle velocities must be multiplied.
 *
 * The half step (dt/2) is split into nc_tchain sub-steps. Each sub-step does
 * the following:
 *   1. Kick the chain velocities by a quarter sub-step, from the end of the
 *      chain down to eta_dot[0], with tdrag_factor applied.
 *   2. Rescale the kinetic energy and advance the chain positions eta.
 *   3. Kick the chain velocities by another quarter sub-step, from eta_dot[0]
 *      back up the chain.
 *
 * grid, mpp  - not used here; the kinetic energy is supplied by the caller.
 * tstat      - thermostat state. eta, eta_dot, eta_dotdot and eta_mass are
 *              updated in place. eta_dot[mtchain] is read (never written) as
 *              the end of the chain, so the arrays need mtchain + 1 entries.
 * progress   - fraction of the run completed. The target temperature is
 *              interpolated linearly from tstart to tstop.
 * ke_current - current kinetic measure, compared against
 *              tdof * boltz * t_target (presumably 2 * kinetic energy;
 *              TODO confirm against the caller).
 *
 * Returns the product of the velocity scaling factors of all sub-steps.
 */
real nh_temp_integrate(cellgrid_t *grid, mpp_t *mpp, nh_tstat_t *tstat, real progress, real ke_current){
  (void)grid;
  (void)mpp;
  real t_target = tstat->tstart + (tstat->tstop - tstat->tstart) * progress;
  real ke_target = tstat->tdof * tstat->boltz * t_target;
  real *eta = tstat->eta;
  real *eta_mass = tstat->eta_mass;
  real *eta_dot = tstat->eta_dot;
  real *eta_dotdot = tstat->eta_dotdot;

  /* Sub-step increments: the half step dt/2 is divided nc_tchain ways. */
  const double sub_half = tstat->ncfac * tstat->dt * 0.5;
  const double sub_quarter = tstat->ncfac * tstat->dt * 0.25;
  const double sub_eighth = tstat->ncfac * tstat->dt * 0.125;

  /* Refresh the chain masses for the current target temperature. */
  eta_mass[0] = ke_target * tstat->tdamp * tstat->tdamp;
  for (int ich = 1; ich < tstat->mtchain; ich ++) {
    eta_mass[ich] = tstat->boltz * t_target * tstat->tdamp * tstat->tdamp;
  }
  if (eta_mass[0] > 0) {
    eta_dotdot[0] = (ke_current - ke_target) / eta_mass[0];
  } else {
    eta_dotdot[0] = 0;
  }

  real factor_v = 1.0;
  for (int iloop = 0; iloop < tstat->nc_tchain; iloop ++) {
    /* Downward sweep: each link is scaled by the link above it and damped. */
    for (int ich = tstat->mtchain - 1; ich >= 0; ich --) {
      real expfac = exp(-sub_eighth * eta_dot[ich+1]);
      eta_dot[ich] = ((eta_dot[ich] * expfac) + eta_dotdot[ich] * sub_quarter) * tstat->tdrag_factor * expfac;
    }

    /* Scale velocities. Update the kinetic energy analytically instead of
       recomputing it from the particles. */
    real factor_eta = exp(-sub_half * eta_dot[0]);
    factor_v *= factor_eta;
    ke_current *= factor_eta * factor_eta;
    if (eta_mass[0] > 0) {
      eta_dotdot[0] = (ke_current - ke_target) / eta_mass[0];
    } else {
      eta_dotdot[0] = 0;
    }

    for (int ich = 0; ich < tstat->mtchain; ich ++) {
      eta[ich] += sub_half * eta_dot[ich];
    }

    /* Upward sweep (no drag). Each later link's force uses the
       already-updated velocity of the link below it. Links with zero mass
       (t_target == 0) get zero force instead of 0/0 = NaN. */
    for (int ich = 0; ich < tstat->mtchain; ich ++) {
      real expfac = exp(-sub_eighth * eta_dot[ich + 1]);
      if (ich > 0) {
        if (eta_mass[ich] > 0) {
          eta_dotdot[ich] = (eta_mass[ich - 1] * eta_dot[ich - 1] * eta_dot[ich - 1] - tstat->boltz * t_target) / eta_mass[ich];
        } else {
          eta_dotdot[ich] = 0;
        }
      }
      eta_dot[ich] = (eta_dot[ich] * expfac + eta_dotdot[ich] * sub_quarter) * expfac;
    }
  }
  return factor_v;
}